{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aco import ACO\n",
    "import numpy as np\n",
    "import torch\n",
    "import logging\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Value passed as the last positional argument to the ACO constructor in solve().\n",
    "N_ANTS = 10\n",
    "# Cumulative iteration checkpoints: solve() runs ACO for the difference between\n",
    "# consecutive entries and records one objective per entry, so the entries are\n",
    "# expected to be in ascending order.\n",
    "N_ITERATIONS = [1, 10, 20, 30, 40, 50]\n",
    "\n",
    "def heuristics_reevo(item_attr1: np.ndarray, item_attr2: np.ndarray) -> np.ndarray:\n",
    "    n, m = item_attr2.shape\n",
    "    \n",
    "    # Normalize item_attr1 and item_attr2\n",
    "    item_attr1_norm = (item_attr1 - np.min(item_attr1)) / (np.max(item_attr1) - np.min(item_attr1))\n",
    "    item_attr2_norm = (item_attr2 - np.min(item_attr2)) / (np.max(item_attr2) - np.min(item_attr2))\n",
    "    \n",
    "    # Calculate the average value of normalized attribute 1\n",
    "    avg_attr1 = np.mean(item_attr1_norm)\n",
    "    \n",
    "    # Calculate the maximum value of normalized attribute 2 for each item\n",
    "    max_attr2 = np.max(item_attr2_norm, axis=1)\n",
    "    \n",
    "    # Calculate the sum of normalized attribute 2 for each item\n",
    "    sum_attr2 = np.sum(item_attr2_norm, axis=1)\n",
    "    \n",
    "    # Calculate the standard deviation of normalized attribute 2 for each item\n",
    "    std_attr2 = np.std(item_attr2_norm, axis=1)\n",
    "    \n",
    "    # Calculate the heuristics based on a combination of normalized attributes 1 and 2,\n",
    "    # while considering the average, sum, and standard deviation of normalized attribute 2\n",
    "    heuristics = (item_attr1_norm / max_attr2) * (item_attr1_norm / avg_attr1) * (item_attr1_norm / sum_attr2) * (1 / std_attr2)\n",
    "    \n",
    "    # Normalize the heuristics to a range of [0, 1]\n",
    "    heuristics = (heuristics - np.min(heuristics)) / (np.max(heuristics) - np.min(heuristics))\n",
    "    \n",
    "    return heuristics\n",
    "\n",
    "\n",
    "def solve(prize: np.ndarray, weight: np.ndarray):\n",
    "    \"\"\"Run ACO on one instance and collect an objective at every N_ITERATIONS checkpoint.\n",
    "\n",
    "    The heuristic vector comes from ``heuristics_reevo`` (called on copies of the\n",
    "    inputs), is shifted by 1e-9 and then floored at 1e-9 before being handed to\n",
    "    ``ACO`` together with the prizes, the weights and ``N_ANTS``.\n",
    "\n",
    "    Args:\n",
    "        prize: per-item values; the heuristic computed from it must have shape (n,).\n",
    "        weight: 2-D array whose first dimension n is the number of items.\n",
    "\n",
    "    Returns:\n",
    "        A list with one Python scalar per entry of ``N_ITERATIONS``: the first\n",
    "        element returned by ``aco.run`` at that checkpoint (presumably the best\n",
    "        objective found so far; confirm against the ACO implementation).\n",
    "    \"\"\"\n",
    "    n_items, _ = weight.shape\n",
    "    floor = 1e-9\n",
    "\n",
    "    heu = heuristics_reevo(prize.copy(), weight.copy()) + floor\n",
    "    assert heu.shape == (n_items,)\n",
    "    heu[heu < floor] = floor\n",
    "\n",
    "    solver = ACO(\n",
    "        torch.from_numpy(prize),\n",
    "        torch.from_numpy(weight),\n",
    "        torch.from_numpy(heu),\n",
    "        N_ANTS,\n",
    "    )\n",
    "\n",
    "    # The same solver object is resumed between checkpoints, so each call only\n",
    "    # runs the iterations still missing to reach the next cumulative target.\n",
    "    results = []\n",
    "    done = 0\n",
    "    for checkpoint in N_ITERATIONS:\n",
    "        best, _ = solver.run(checkpoint - done)\n",
    "        results.append(best.item())\n",
    "        done = checkpoint\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[*] Average for 100, 1 iterations: 16.183777390580715\n",
      "[*] Average for 100, 10 iterations: 16.95788758133592\n",
      "[*] Average for 100, 20 iterations: 17.068556032421405\n",
      "[*] Average for 100, 30 iterations: 17.113623508631083\n",
      "[*] Average for 100, 40 iterations: 17.113623508631083\n",
      "[*] Average for 100, 50 iterations: 17.137104057227354\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate solve() on every instance of each validation set and print the mean\n",
    "# objective per N_ITERATIONS checkpoint.\n",
    "for problem_size in [100]:\n",
    "    dataset_path = f\"dataset/val{problem_size}_dataset.npz\"\n",
    "    # np.load on an .npz returns an archive object that keeps the file open;\n",
    "    # the context manager closes it once both arrays have been read into memory.\n",
    "    with np.load(dataset_path) as dataset:\n",
    "        prizes, weights = dataset['prizes'], dataset['weights']\n",
    "    n_instances = prizes.shape[0]\n",
    "    if weights.shape[0] != n_instances:\n",
    "        # zip() below would otherwise silently drop the unmatched instances.\n",
    "        raise ValueError(\n",
    "            f\"{dataset_path}: {n_instances} prize rows but {weights.shape[0]} weight rows\"\n",
    "        )\n",
    "    # Only visible if logging has been configured to emit INFO records.\n",
    "    logging.info(f\"[*] Evaluating {dataset_path}\")\n",
    "\n",
    "    # One row per instance, one column per checkpoint in N_ITERATIONS.\n",
    "    objs = [solve(prize, weight) for prize, weight in zip(prizes, weights)]\n",
    "\n",
    "    # Average objective value for all instances\n",
    "    mean_obj = np.mean(objs, axis=0)\n",
    "    for n_iter, obj in zip(N_ITERATIONS, mean_obj):\n",
    "        print(f\"[*] Average for {problem_size}, {n_iter} iterations: {obj}\")\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
